Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
if not self.experimental:
    raise ValueError(
        "NeuronGRPOTrainer is experimental and not production-ready. To proceed, set `experimental=True` in "
        "your NeuronGRPOConfig. This flag exists to ensure users are aware of the current state of the implementation."
    )
```
For now we disable access to the `NeuronGRPOTrainer`.
dacorvo left a comment:
The change in the CI workflow aside, this looks good to me, although I did not go into the details of the trainer algorithm.
For the next step, you will need to add a `load_state_dict` method in `NxDPretrainedModel`, since the existing `load_weights` method reads weights from a path. You will also need to provide a modified `checkpoint_loader_fn` that loads weights from a state dict directly, as for now it too expects to load the state dict from a path.
Then, in vLLM, you will need to call the `load_state_dict` method when required.
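A minimal sketch of the shape this could take (everything except `NxDPretrainedModel` and `load_weights` is a hypothetical name, and the real weight-placement logic on Neuron is certainly more involved):

```python
import torch

def state_dict_checkpoint_loader_fn(state_dict: dict[str, torch.Tensor]):
    """Hypothetical variant of checkpoint_loader_fn that serves weights from an
    in-memory state_dict instead of reading them from a path."""
    def load(key: str) -> torch.Tensor:
        return state_dict[key]
    return load

class NxDPretrainedModel:
    def load_weights(self, path: str) -> None:
        ...  # existing method: reads weights from a path

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        # Suggested new method: same placement logic as load_weights,
        # fed directly from the in-memory state_dict.
        loader = state_dict_checkpoint_loader_fn(state_dict)
        for name, param in self.named_parameters():  # assumes an nn.Module-like API
            param.data.copy_(loader(name))
```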
```diff
 run: |
   source aws_neuron_venv_pytorch/bin/activate
-  python -m pip install .[neuronx,tests]
+  python -m pip install .[neuronx,tests,training]
```
You should not install the training requirements for all workflows (a way to scope them is sketched below):
- this can create conflicts
- this will hide any import errors
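One possible way to scope the extras, sketched against the composite action (the `extras` input is hypothetical, not part of the PR):

```yaml
# .github/actions/install_optimum_neuron/action.yml (sketch)
inputs:
  extras:
    description: "Additional extras to install on top of neuronx,tests (e.g. training)"
    default: ""
runs:
  using: composite
  steps:
    - shell: bash
      run: |
        source aws_neuron_venv_pytorch/bin/activate
        if [ -n "${{ inputs.extras }}" ]; then
          python -m pip install ".[neuronx,tests,${{ inputs.extras }}]"
        else
          python -m pip install ".[neuronx,tests]"
        fi
```

Workflows that exercise the trainer would then pass `extras: training`, while all other workflows keep the lighter install.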
Pull request overview
This PR adds partial support for GRPO (Group Relative Policy Optimization) training on Neuron (Trainium) devices through the new NeuronGRPOTrainer class. The implementation includes XLA-specific optimizations and modifications to work with the Torch XLA backend, though several core features remain unimplemented (vLLM integration, weight synchronization, tensor parallelism).
Changes:
- Adds `NeuronGRPOTrainer` with XLA-optimized implementations for generation, scoring, and loss computation
- Introduces `NeuronGRPOConfig` for configuration with an experimental flag requirement
- Implements XLA-friendly utility functions (padding, entropy, statistical operations) in `trl_utils.py`
- Adds custom vLLM client implementations with a CPU communicator and a mock client for testing
- Updates `NeuronTrainer` to support a `_prepare_inputs` hook and replaces `xm.mark_step()` with `torch_xla.sync()`
- Modifies LoRA transformation utilities to handle missing weights more gracefully
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 19 comments.
Summary per file:
| File | Description |
|---|---|
| optimum/neuron/trainers/grpo_trainer.py | Core GRPO trainer implementation with XLA optimizations (1414 lines, new file) |
| optimum/neuron/trainers/grpo_config.py | Configuration class with validation and experimental flag (118 lines, new file) |
| optimum/neuron/trainers/trl_utils.py | XLA-optimized utility functions for padding, statistics, and sampling (270 lines) |
| optimum/neuron/trainers/extras/vllm_client.py | Custom vLLM clients for Neuron with CPU communicator and mock implementation (213 lines, new file) |
| optimum/neuron/trainers/transformers.py | Updates to NeuronTrainer for _prepare_inputs hook and torch_xla.sync() migration |
| optimum/neuron/trainers/utils.py | Adds move_inputs_to_device utility and updates XLAPrefetchIterator |
| optimum/neuron/models/training/transformations_utils.py | Converts LoRA weight errors to silent skips for flexibility |
| optimum/neuron/trainers/metrics/collector.py | Refactors get_metric_unit for cleaner logic |
| optimum/neuron/utils/__init__.py | Exports is_vllm_available function |
| optimum/neuron/__init__.py | Exports NeuronGRPOTrainer and NeuronGRPOConfig |
| .github/actions/install_optimum_neuron/action.yml | Adds training extras to CI installation |
```python
raise Exception(f"Request failed: {response.status_code}, {response.text}")

world_size = vllm_world_size + 1  # add the client to the world
self.rank = vllm_world_size  # the client's rank is the last process

# Initialize weight update group
url = f"{self.base_url}/init_communicator/"

# Use dummy UUID for CPU/Neuron environments
client_device_uuid = "42"

# On the server side, the host is set to 0.0.0.0
response = self.session.post(
    url,
    json={
        "host": "0.0.0.0",
        "port": self.group_port,
        "world_size": world_size,
        "client_device_uuid": client_device_uuid,
    },
)
if response.status_code != 200:
    raise Exception(f"Request failed: {response.status_code}, {response.text}")
```
Using a bare Exception type is not recommended. Use a more specific exception type like RuntimeError or create a custom exception class for better error handling and debugging. This applies to both lines 90 and 112.
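For instance (the exception class name here is illustrative, not from the PR):

```python
import requests

class VLLMRequestError(RuntimeError):
    """Raised when an HTTP request to the vLLM server fails."""

def check_response(response: requests.Response) -> None:
    # Replaces the bare `raise Exception(...)` with a specific, catchable type.
    if response.status_code != 200:
        raise VLLMRequestError(f"Request failed: {response.status_code}, {response.text}")
```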
```python
def __init__(self, tokenizer, max_completion_length=256, min_completion_length=10, seed=None):
    self.tokenizer = tokenizer
    self.max_completion_length = max_completion_length
    self.min_completion_length = min(min_completion_length, max_completion_length)
    self.random = random.Random(seed)

    logger.warning(
        "Using MockVLLMClient for neuron_parallel_compile or testing. "
        "This generates echo completions and should only be used for compilation/testing."
    )
```
MockVLLMClient inherits from VLLMClient but doesn't call super().__init__(). This means parent class initialization is skipped, which could cause issues if the parent class (TRLVLLMClient via VLLMClient) expects certain attributes to be initialized. The parent class likely sets up self.session, self.base_url, self.host, self.group_port and other attributes that may be accessed. Consider either calling super().__init__() with appropriate parameters or inheriting directly from object if the parent's initialization is not needed.
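A sketch of one option: deliberately skip the parent constructor but stub the attributes it is believed to set (the attribute names here are assumptions based on the comment above):

```python
import random

class MockVLLMClient(VLLMClient):
    def __init__(self, tokenizer, max_completion_length=256, min_completion_length=10, seed=None):
        # Deliberately do NOT call super().__init__() (it would try to reach a server);
        # instead, stub the attributes callers may expect the parent to have set.
        self.session = None
        self.base_url = None
        self.host = None
        self.group_port = None
        self.tokenizer = tokenizer
        self.max_completion_length = max_completion_length
        self.min_completion_length = min(min_completion_length, max_completion_length)
        self.random = random.Random(seed)
```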
```python
    Compute the minimum value of a tensor, ignoring NaNs.
    """
    mask = torch.isnan(tensor)
    filled = torch.where(mask, torch.tensor(float("inf"), device=tensor.device), tensor)
```
Creating a new tensor with torch.tensor(float("inf"), device=tensor.device) for each call can cause XLA graph fragmentation. Consider pre-creating constant tensors during initialization (similar to _one_float, _inf_float in the trainer) and reusing them, or use tensor.new_full((), float("inf")), which reuses the existing tensor's dtype and device.
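Applied to the snippet above, the `new_full` variant could look like this (a sketch; the enclosing function is named `nanmin` here for illustration):

```python
import torch

def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """Minimum of a tensor, ignoring NaNs."""
    mask = torch.isnan(tensor)
    # new_full with an empty size creates a scalar tensor that inherits the
    # input's dtype and device, instead of building a fresh constant from Python.
    filled = torch.where(mask, tensor.new_full((), float("inf")), tensor)
    return filled.min()
```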
```python
    num_items_in_batch,
    sampling_per_token_logps_list,
    forward_kwargs,
) = self._generate(prompts, images)
```
The method _generate is called but not defined in this class. This will cause an AttributeError at runtime. The method should either be defined in this class or inherited from the parent GRPOTrainer class via the _GRPOTrainer intermediate class. Based on the usage, it should return a tuple of (prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, forward_kwargs).
```python
# Gradient accumulation requires scaled loss
self.model_accepts_loss_kwargs = False
```
The _tag_names attribute is accessed but never defined in this class. This will cause an AttributeError at runtime if the model has the add_model_tags method. This attribute should be defined during initialization or inherited from a parent class.
Suggested change:
```python
# Ensure _tag_names exists before being passed to model.add_model_tags.
if not hasattr(self, "_tag_names"):
    self._tag_names = set()
```
```python
prompt_ids.append(prompt_tokens)

# Generate n completions per prompt
for _ in range(n):
    # Random completion length within bounds
    max_len = min(max_tokens, self.max_completion_length)
    completion_length = self.random.randint(self.min_completion_length, max_len)

    # Echo mode: cycle through prompt tokens
    if len(prompt_tokens) > 0:
        completion = [prompt_tokens[i % len(prompt_tokens)] for i in range(completion_length)]
    else:
        # Fallback if prompt is empty
        completion = [fallback_token_id] * completion_length

    completion_ids.append(completion)

    # Logprobs: simulate higher confidence for echoed tokens
    completion_logprobs = [-self.random.uniform(0.5, 2.0) for _ in range(completion_length)]
    logprobs.append(completion_logprobs)
```
The MockVLLMClient generates n completions for each prompt, but only appends each prompt to prompt_ids once (line 176). This means prompt_ids will have length equal to the number of prompts, while completion_ids will have length equal to num_prompts * n. This mismatch in list lengths could cause issues if the caller expects both lists to have the same length. Based on the usage in _generate_single_turn, it appears the expected behavior is for prompt_ids to be repeated for each completion.
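A sketch of the implied fix, with names taken from the snippet above: append the prompt entry once per completion so both lists end up with num_prompts * n entries.

```python
# Inside the per-prompt loop of the mock generation (sketch):
for _ in range(n):
    prompt_ids.append(prompt_tokens)  # keep prompt_ids aligned with completion_ids
    max_len = min(max_tokens, self.max_completion_length)
    completion_length = self.random.randint(self.min_completion_length, max_len)
    ...
```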
```python
# Send weights to vLLM server (only main process for server mode)
for name, weight in original_weights.items():
    # Clean up parameter name for vLLM
    name = self._fix_param_name_to_vllm(name)
```
The method _fix_param_name_to_vllm is called but not defined in this class or any visible parent class. This will cause an AttributeError when executing this code path. The method should be implemented or inherited from a parent class. Based on the context, it appears this method should clean up parameter names for vLLM compatibility.
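If it needs to be defined locally, a hypothetical implementation might be simple prefix stripping (the exact prefixes vLLM's weight loader expects are an assumption here):

```python
def _fix_param_name_to_vllm(self, name: str) -> str:
    # Hypothetical: strip wrapper prefixes that vLLM's weight loader does not expect.
    for prefix in ("module.", "base_model.model."):
        if name.startswith(prefix):
            name = name[len(prefix):]
    return name
```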
```diff
 if to_concat_and_duplicate_name is None or to_unfuse_name is None:
-    raise ValueError(
-        f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}."
-    )
+    continue
```
Similar to the previous issue, this converts a hard error into a silent skip. This could hide configuration problems. Consider logging when weights are not found to aid debugging.
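For example (assuming a module-level `logger` is available in `transformations_utils.py`):

```python
if to_concat_and_duplicate_name is None or to_unfuse_name is None:
    # Log instead of silently skipping, so configuration problems stay visible.
    logger.warning(
        f"Could not find LoRA weights for {module_fully_qualified_name} "
        f"with param name {param_name}; skipping."
    )
    continue
```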
```python
# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
# if self.vllm_mode == "server" and self.accelerator.is_main_process:
#     self.vllm_client.update_named_param(name, weight)
# elif self.vllm_mode == "colocate":
#     llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
#     llm_model.load_weights([(name, weight)])
```
This comment appears to contain commented-out code.
Suggested change:
```diff
-# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
-# if self.vllm_mode == "server" and self.accelerator.is_main_process:
-#     self.vllm_client.update_named_param(name, weight)
-# elif self.vllm_mode == "colocate":
-#     llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-#     llm_model.load_weights([(name, weight)])
+# TODO: Support updating vLLM weights for NeuronPeftModel in server and colocate modes.
+# This will be implemented in a future PR as part of the vLLM integration work.
```
```python
name = self._fix_param_name_to_vllm(name)

# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
```
The variable `name` is assigned but never used.
Suggested change:
```diff
-name = self._fix_param_name_to_vllm(name)
-# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
+# TODO: Currently not supported, to implement asap in later PRs with vLLM integration.
+# name = self._fix_param_name_to_vllm(name)
```
dacorvo left a comment:
No more blockers, but I will review in more detail tomorrow. Copilot detected some issues that may be worth considering.
What does this PR do?
This PR adds partial support for GRPO. It was broken down into smaller PRs:
- optimum/neuron/accelerate#1042

It adds the `NeuronGRPOTrainer` with a set of optimizations and modifications for the Torch XLA backend, used to run things on Trainium instances. There are still core missing features:
- `NeuronGRPOTrainer` <-> vLLM